df <- readRDS("../data/models/social-risk-crash-rate-data.rds")
df <- df %>%
mutate(
post_pandemic = ifelse(year < 2020, "pre", "post") # pre vs. post 2020
) %>%
mutate(post_pandemic = as.integer(post_pandemic == "post"))
df <- df %>%
mutate(
car_density_interaction = pct_vehicle * log1p(total_population),
income_vehicle_interaction = median_income * pct_vehicle
)
YEARI will treat year as a factor and one-hot encode it.
df$year <- as.factor(df$year)
year_dummies <- model.matrix(~ year - 1, data = df)
df <- cbind(df[ , !(names(df) %in% c("year"))], year_dummies)
We will remove all possible target variables and keep only one per model training.
# Choose your target variable (e.g., crash rate per 1,000 residents)
target_var <- "crash_rate_per_1000"
# Remove all target variables except selected
cols_to_remove <- grep("per_1000",
names(df),
value = TRUE)
cols_to_remove <- setdiff(cols_to_remove, target_var) # keep this column
df <- df %>% select(-all_of(cols_to_remove),)
# Create feature matrix and target vector
X <- df %>% select(-target_var, -borough, -total_population, -geoid)
y <- df[[target_var]]
glimpse(X)
Rows: 13,518
Columns: 47
$ pct_male_population <dbl> 12.460865, 11.694881, 12.229106, 13.5504…
$ pct_female_population <dbl> 10.846132, 11.629654, 11.054169, 9.83784…
$ pct_white_population <dbl> 4.1225202, 3.6815086, 2.5856754, 2.36100…
$ pct_black_population <dbl> 2.435429, 2.630197, 2.833673, 2.412624, …
$ pct_asian_population <dbl> 0.19803330, 0.14388387, 0.23376782, 0.30…
$ pct_hispanic_population <dbl> 6.411328, 6.607147, 5.982424, 6.079353, …
$ pct_foreign_born <dbl> 2.909566, 2.442189, 2.187254, 2.200420, …
$ pct_age_under_18 <dbl> 2.130762, 2.643626, 2.333613, 2.387771, …
$ pct_age_18_34 <dbl> 1.715654, 1.653705, 1.882339, 1.963363, …
$ pct_age_35_64 <dbl> 3.296112, 3.079115, 3.041014, 3.043500, …
$ pct_age_65_plus <dbl> 1.80895804, 1.78607845, 1.56929356, 1.57…
$ median_income <dbl> 58582.658, 49964.513, 68000.000, 70867.0…
$ pct_income_under_25k <dbl> 0.5445916, 0.5544325, 0.5325841, 0.51425…
$ pct_income_25k_75k <dbl> 1.0568123, 1.1376418, 0.9533662, 0.87749…
$ pct_income_75k_plus <dbl> 0.9273290, 0.8824877, 1.2379531, 1.26939…
$ pct_below_poverty <dbl> 1.9574830, 1.9510653, 1.8091597, 1.93851…
$ median_gross_rent <dbl> 1579.1133, 1524.3577, 1701.0000, 1740.00…
$ pct_owner_occupied <dbl> 1.33318303, 1.35699756, 1.58439699, 1.46…
$ pct_renter_occupied <dbl> 1.2085934, 1.2305140, 1.1518859, 1.20575…
$ pct_no_vehicle <dbl> 0.5122207, 0.6522735, 0.6118619, 0.62705…
$ pct_less_than_hs <dbl> 1.4509748, 1.1951954, 1.1302166, 1.04954…
$ pct_hs_diploma <dbl> 1.5576081, 1.4196542, 1.2704773, 1.38410…
$ pct_some_college <dbl> 0.8473540, 0.8997538, 0.7256966, 0.93484…
$ pct_associates_degree <dbl> 0.4912749, 0.3280552, 0.3923234, 0.32308…
$ pct_bachelors_degree <dbl> 0.9482748, 0.9457966, 0.8537607, 0.90234…
$ pct_graduate_degree <dbl> 0.3998749, 0.7098271, 1.1037907, 0.96734…
$ pct_in_labor_force <dbl> 3.566504, 3.430191, 3.555304, 3.595995, …
$ pct_not_in_labor_force <dbl> 3.448445, 3.326595, 3.162980, 3.106587, …
$ unemployment_rate <dbl> 15.750133, 13.478747, 10.806175, 11.5364…
$ pct_commute_short <dbl> 0.7350082, 0.4604284, 0.1585556, 0.41676…
$ pct_commute_medium <dbl> 1.7061331, 1.6882374, 1.2521824, 1.25792…
$ pct_commute_long <dbl> 2.477320, 2.685832, 3.553271, 3.737464, …
$ pct_carpool <dbl> 0.35798327, 0.31462606, 0.00000000, 0.04…
$ pct_public_transit <dbl> 2.056500, 2.540030, 2.898721, 2.366742, …
$ pct_walk <dbl> 0.37702494, 0.30695226, 0.13009688, 0.76…
$ pct_bike <dbl> 0.00000000, 0.00000000, 0.00000000, 0.00…
$ pct_work_from_home <dbl> 0.01713750, 0.04604284, 0.04675356, 0.05…
$ pct_vehicle <dbl> 2.0165122, 1.9222885, 2.1120415, 2.03409…
$ post_pandemic <int> 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0…
$ car_density_interaction <dbl> 21.871938, 20.835595, 22.817540, 22.1003…
$ income_vehicle_interaction <dbl> 118132.643, 96046.210, 143618.819, 14415…
$ year2018 <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0…
$ year2019 <dbl> 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1…
$ year2020 <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0…
$ year2021 <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0…
$ year2022 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0…
$ year2023 <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0…
What This Does - Uses R’s
xgboost::xgb.cv() to evaluate each parameter set.
- Optuna (Python) handles the search space and Bayesian optimization. -
The final best parameters are applied to fit the
final_model. - Search space: Instead of
predefined grids, trial$suggest_float() and
trial$suggest_int() explore a range of values. -
Best parameters: study$best_params holds
the optimal hyperparameters.
## CONVERT TO DMATRIX
dtrain_all <- xgb.DMatrix(data = as.matrix(X), label = y)
## Start Python venv
reticulate::use_virtualenv("r-reticulate", required = TRUE)
## OPTUNA-BASED SPATIAL CV
optuna <- import("optuna")
boroughs <- unique(df$borough)
folds <- lapply(boroughs, function(b) which(df$borough != b))
# Optuna objective
objective <- function(trial) {
params <- list(
booster = "gbtree",
eta = trial$suggest_float("eta", 0.01, 0.3, log = TRUE),
max_depth = trial$suggest_int("max_depth", 3, 12),
min_child_weight = trial$suggest_int("min_child_weight", 1, 10),
subsample = trial$suggest_float("subsample", 0.5, 1.0),
colsample_bytree = trial$suggest_float("colsample_bytree", 0.5, 1.0),
gamma = trial$suggest_float("gamma", 0, 10),
lambda = trial$suggest_float("lambda", 0, 10),
alpha = trial$suggest_float("alpha", 0, 10)
)
rmse_scores <- numeric(length(folds))
for (i in seq_along(folds)) {
train_idx <- folds[[i]]
valid_idx <- setdiff(seq_len(nrow(dtrain_all)), train_idx)
dtrain <- xgb.DMatrix(data = as.matrix(X[train_idx, ]), label = y[train_idx])
dvalid <- xgb.DMatrix(data = as.matrix(X[valid_idx, ]), label = y[valid_idx])
model <- xgb.train(
params = params,
data = dtrain,
nrounds = 500,
watchlist = list(val = dvalid),
early_stopping_rounds = 20,
verbose = 0
)
rmse_scores[i] <- min(model$evaluation_log$val_rmse)
}
preds <- predict(model, as.matrix(X[valid_idx, ]))
return(Metrics::rmse(y[valid_idx], preds))
}
# Run Optuna study
set.seed(2025)
study <- optuna$create_study(direction = "minimize")
study$optimize(objective, n_trials = 50)
best_params <- study$best_params
print(best_params)
$eta
[1] 0.03632745
$max_depth
[1] 4
$min_child_weight
[1] 2
$subsample
[1] 0.8879885
$colsample_bytree
[1] 0.7059848
$gamma
[1] 4.097305
$lambda
[1] 6.082163
$alpha
[1] 9.197171
# Set seed
set.seed(2025)
# Split by index
train_index <- createDataPartition(y, p = 0.8, list = FALSE)
X_train <- X[train_index, ]
y_train <- y[train_index]
X_test <- X[-train_index, ]
y_test <- y[-train_index]
# Convert to xgb.DMatrix
dtrain <- xgb.DMatrix(data = as.matrix(X_train), label = y_train)
dtest <- xgb.DMatrix(data = as.matrix(X_test), label = y_test)
# Set seed
set.seed(2025)
# Training with parallel processing
final_model <- xgb.train(
params = list(
eta = best_params$eta,
max_depth = best_params$max_depth,
gamma = best_params$gamma,
colsample_bytree = best_params$colsample_bytree,
min_child_weight = best_params$min_child_weight,
subsample = best_params$subsample,
objective = "reg:squarederror",
eval_metric = "rmse"
),
data = dtrain,
nrounds = 1000,
watchlist = list(train = dtrain, test = dtest),
early_stopping_rounds = 20,
verbose = 1,
nthread = detectCores() - 1
)
[1] train-rmse:2.024689 test-rmse:2.001502
Multiple eval metrics are present. Will use test_rmse for early stopping.
Will train until test_rmse hasn't improved in 20 rounds.
[2] train-rmse:1.998355 test-rmse:1.975847
[3] train-rmse:1.974319 test-rmse:1.953049
[4] train-rmse:1.952721 test-rmse:1.932999
[5] train-rmse:1.930054 test-rmse:1.913089
[6] train-rmse:1.908801 test-rmse:1.895099
[7] train-rmse:1.891739 test-rmse:1.878780
[8] train-rmse:1.871690 test-rmse:1.861159
[9] train-rmse:1.853396 test-rmse:1.845040
[10] train-rmse:1.835669 test-rmse:1.828367
[11] train-rmse:1.821744 test-rmse:1.816867
[12] train-rmse:1.805291 test-rmse:1.802956
[13] train-rmse:1.789961 test-rmse:1.788833
[14] train-rmse:1.774891 test-rmse:1.776486
[15] train-rmse:1.761938 test-rmse:1.764305
[16] train-rmse:1.748162 test-rmse:1.751820
[17] train-rmse:1.735951 test-rmse:1.741813
[18] train-rmse:1.725182 test-rmse:1.732486
[19] train-rmse:1.715346 test-rmse:1.726037
[20] train-rmse:1.703982 test-rmse:1.716642
[21] train-rmse:1.693773 test-rmse:1.710661
[22] train-rmse:1.683067 test-rmse:1.702340
[23] train-rmse:1.674536 test-rmse:1.695314
[24] train-rmse:1.665630 test-rmse:1.688261
[25] train-rmse:1.656567 test-rmse:1.681153
[26] train-rmse:1.648420 test-rmse:1.675137
[27] train-rmse:1.640781 test-rmse:1.669599
[28] train-rmse:1.632523 test-rmse:1.664210
[29] train-rmse:1.627208 test-rmse:1.659251
[30] train-rmse:1.621461 test-rmse:1.654814
[31] train-rmse:1.615179 test-rmse:1.649371
[32] train-rmse:1.608203 test-rmse:1.644481
[33] train-rmse:1.601043 test-rmse:1.639614
[34] train-rmse:1.595589 test-rmse:1.635750
[35] train-rmse:1.590173 test-rmse:1.630972
[36] train-rmse:1.584897 test-rmse:1.627317
[37] train-rmse:1.579926 test-rmse:1.623348
[38] train-rmse:1.576162 test-rmse:1.621127
[39] train-rmse:1.571841 test-rmse:1.618717
[40] train-rmse:1.566880 test-rmse:1.615505
[41] train-rmse:1.563612 test-rmse:1.611808
[42] train-rmse:1.558898 test-rmse:1.607830
[43] train-rmse:1.554437 test-rmse:1.605901
[44] train-rmse:1.550984 test-rmse:1.604009
[45] train-rmse:1.546700 test-rmse:1.599691
[46] train-rmse:1.542524 test-rmse:1.596537
[47] train-rmse:1.539138 test-rmse:1.593237
[48] train-rmse:1.536203 test-rmse:1.591365
[49] train-rmse:1.533314 test-rmse:1.589556
[50] train-rmse:1.529964 test-rmse:1.587873
[51] train-rmse:1.527428 test-rmse:1.586551
[52] train-rmse:1.524310 test-rmse:1.584876
[53] train-rmse:1.521837 test-rmse:1.582542
[54] train-rmse:1.517478 test-rmse:1.581230
[55] train-rmse:1.514896 test-rmse:1.579555
[56] train-rmse:1.512605 test-rmse:1.578785
[57] train-rmse:1.509587 test-rmse:1.577967
[58] train-rmse:1.505918 test-rmse:1.576259
[59] train-rmse:1.503651 test-rmse:1.574108
[60] train-rmse:1.500156 test-rmse:1.572615
[61] train-rmse:1.496597 test-rmse:1.571579
[62] train-rmse:1.494282 test-rmse:1.570307
[63] train-rmse:1.491739 test-rmse:1.568419
[64] train-rmse:1.489145 test-rmse:1.567036
[65] train-rmse:1.486202 test-rmse:1.566324
[66] train-rmse:1.483291 test-rmse:1.565861
[67] train-rmse:1.480517 test-rmse:1.565620
[68] train-rmse:1.478047 test-rmse:1.565305
[69] train-rmse:1.474847 test-rmse:1.564701
[70] train-rmse:1.471981 test-rmse:1.564082
[71] train-rmse:1.470111 test-rmse:1.562768
[72] train-rmse:1.468874 test-rmse:1.561247
[73] train-rmse:1.467763 test-rmse:1.560164
[74] train-rmse:1.465884 test-rmse:1.559562
[75] train-rmse:1.463966 test-rmse:1.559302
[76] train-rmse:1.462603 test-rmse:1.557923
[77] train-rmse:1.460169 test-rmse:1.557860
[78] train-rmse:1.458633 test-rmse:1.557181
[79] train-rmse:1.456757 test-rmse:1.555342
[80] train-rmse:1.455553 test-rmse:1.554737
[81] train-rmse:1.454042 test-rmse:1.553398
[82] train-rmse:1.451533 test-rmse:1.551736
[83] train-rmse:1.449614 test-rmse:1.550860
[84] train-rmse:1.447784 test-rmse:1.549669
[85] train-rmse:1.445705 test-rmse:1.548429
[86] train-rmse:1.443257 test-rmse:1.547772
[87] train-rmse:1.441915 test-rmse:1.546992
[88] train-rmse:1.440558 test-rmse:1.546702
[89] train-rmse:1.438290 test-rmse:1.545958
[90] train-rmse:1.436691 test-rmse:1.545385
[91] train-rmse:1.434475 test-rmse:1.544466
[92] train-rmse:1.432608 test-rmse:1.543676
[93] train-rmse:1.430924 test-rmse:1.543293
[94] train-rmse:1.428885 test-rmse:1.543319
[95] train-rmse:1.426027 test-rmse:1.542649
[96] train-rmse:1.424238 test-rmse:1.542098
[97] train-rmse:1.423567 test-rmse:1.541733
[98] train-rmse:1.422403 test-rmse:1.541594
[99] train-rmse:1.421314 test-rmse:1.540684
[100] train-rmse:1.419551 test-rmse:1.540291
[101] train-rmse:1.418234 test-rmse:1.539756
[102] train-rmse:1.416881 test-rmse:1.539631
[103] train-rmse:1.414631 test-rmse:1.539938
[104] train-rmse:1.412507 test-rmse:1.539227
[105] train-rmse:1.410264 test-rmse:1.539756
[106] train-rmse:1.408972 test-rmse:1.539833
[107] train-rmse:1.406176 test-rmse:1.539411
[108] train-rmse:1.405678 test-rmse:1.539225
[109] train-rmse:1.404553 test-rmse:1.538956
[110] train-rmse:1.402732 test-rmse:1.539053
[111] train-rmse:1.400331 test-rmse:1.539047
[112] train-rmse:1.398743 test-rmse:1.539190
[113] train-rmse:1.396857 test-rmse:1.538258
[114] train-rmse:1.395947 test-rmse:1.538006
[115] train-rmse:1.394972 test-rmse:1.537657
[116] train-rmse:1.393948 test-rmse:1.537832
[117] train-rmse:1.392602 test-rmse:1.537377
[118] train-rmse:1.391511 test-rmse:1.536610
[119] train-rmse:1.390013 test-rmse:1.535906
[120] train-rmse:1.387949 test-rmse:1.535541
[121] train-rmse:1.387102 test-rmse:1.534979
[122] train-rmse:1.386194 test-rmse:1.534821
[123] train-rmse:1.385178 test-rmse:1.533674
[124] train-rmse:1.383140 test-rmse:1.533575
[125] train-rmse:1.381667 test-rmse:1.533506
[126] train-rmse:1.379677 test-rmse:1.532236
[127] train-rmse:1.378600 test-rmse:1.530953
[128] train-rmse:1.376997 test-rmse:1.530241
[129] train-rmse:1.376078 test-rmse:1.529703
[130] train-rmse:1.374497 test-rmse:1.529502
[131] train-rmse:1.373893 test-rmse:1.529210
[132] train-rmse:1.372259 test-rmse:1.528860
[133] train-rmse:1.371521 test-rmse:1.528506
[134] train-rmse:1.369885 test-rmse:1.528109
[135] train-rmse:1.369005 test-rmse:1.527645
[136] train-rmse:1.368592 test-rmse:1.527627
[137] train-rmse:1.367367 test-rmse:1.527402
[138] train-rmse:1.365340 test-rmse:1.527120
[139] train-rmse:1.363783 test-rmse:1.526280
[140] train-rmse:1.362854 test-rmse:1.525762
[141] train-rmse:1.362326 test-rmse:1.525748
[142] train-rmse:1.361611 test-rmse:1.526153
[143] train-rmse:1.359682 test-rmse:1.525752
[144] train-rmse:1.358349 test-rmse:1.525485
[145] train-rmse:1.356420 test-rmse:1.524857
[146] train-rmse:1.354693 test-rmse:1.524077
[147] train-rmse:1.353734 test-rmse:1.524079
[148] train-rmse:1.352863 test-rmse:1.524114
[149] train-rmse:1.351948 test-rmse:1.523477
[150] train-rmse:1.349867 test-rmse:1.523384
[151] train-rmse:1.349164 test-rmse:1.523430
[152] train-rmse:1.347177 test-rmse:1.522768
[153] train-rmse:1.345161 test-rmse:1.521724
[154] train-rmse:1.343847 test-rmse:1.521730
[155] train-rmse:1.342894 test-rmse:1.520900
[156] train-rmse:1.341756 test-rmse:1.521296
[157] train-rmse:1.341067 test-rmse:1.520886
[158] train-rmse:1.340276 test-rmse:1.520838
[159] train-rmse:1.338684 test-rmse:1.519834
[160] train-rmse:1.337265 test-rmse:1.520292
[161] train-rmse:1.336634 test-rmse:1.519918
[162] train-rmse:1.335069 test-rmse:1.519312
[163] train-rmse:1.334108 test-rmse:1.519483
[164] train-rmse:1.332502 test-rmse:1.518497
[165] train-rmse:1.332140 test-rmse:1.518564
[166] train-rmse:1.331068 test-rmse:1.518146
[167] train-rmse:1.329852 test-rmse:1.518635
[168] train-rmse:1.328766 test-rmse:1.518644
[169] train-rmse:1.328031 test-rmse:1.518370
[170] train-rmse:1.325819 test-rmse:1.517216
[171] train-rmse:1.325244 test-rmse:1.516752
[172] train-rmse:1.323273 test-rmse:1.516412
[173] train-rmse:1.322351 test-rmse:1.516244
[174] train-rmse:1.321004 test-rmse:1.515957
[175] train-rmse:1.319940 test-rmse:1.515938
[176] train-rmse:1.318055 test-rmse:1.515518
[177] train-rmse:1.317552 test-rmse:1.515229
[178] train-rmse:1.316843 test-rmse:1.515641
[179] train-rmse:1.315283 test-rmse:1.515345
[180] train-rmse:1.313461 test-rmse:1.515309
[181] train-rmse:1.312013 test-rmse:1.515195
[182] train-rmse:1.310458 test-rmse:1.515029
[183] train-rmse:1.309118 test-rmse:1.515510
[184] train-rmse:1.307634 test-rmse:1.514602
[185] train-rmse:1.305562 test-rmse:1.514370
[186] train-rmse:1.304643 test-rmse:1.514094
[187] train-rmse:1.303823 test-rmse:1.514313
[188] train-rmse:1.303135 test-rmse:1.513811
[189] train-rmse:1.301987 test-rmse:1.513415
[190] train-rmse:1.301115 test-rmse:1.513052
[191] train-rmse:1.300090 test-rmse:1.512693
[192] train-rmse:1.299628 test-rmse:1.512415
[193] train-rmse:1.298428 test-rmse:1.512194
[194] train-rmse:1.297476 test-rmse:1.511815
[195] train-rmse:1.296244 test-rmse:1.511766
[196] train-rmse:1.295085 test-rmse:1.510876
[197] train-rmse:1.294221 test-rmse:1.510858
[198] train-rmse:1.292849 test-rmse:1.511246
[199] train-rmse:1.290329 test-rmse:1.511142
[200] train-rmse:1.289340 test-rmse:1.510961
[201] train-rmse:1.288004 test-rmse:1.510840
[202] train-rmse:1.287206 test-rmse:1.510956
[203] train-rmse:1.286197 test-rmse:1.510766
[204] train-rmse:1.285647 test-rmse:1.510482
[205] train-rmse:1.284056 test-rmse:1.510332
[206] train-rmse:1.283630 test-rmse:1.510093
[207] train-rmse:1.282816 test-rmse:1.510396
[208] train-rmse:1.281604 test-rmse:1.510221
[209] train-rmse:1.280466 test-rmse:1.509975
[210] train-rmse:1.280214 test-rmse:1.509292
[211] train-rmse:1.279850 test-rmse:1.508799
[212] train-rmse:1.278932 test-rmse:1.508748
[213] train-rmse:1.278448 test-rmse:1.508481
[214] train-rmse:1.277426 test-rmse:1.508186
[215] train-rmse:1.275234 test-rmse:1.508494
[216] train-rmse:1.273992 test-rmse:1.508632
[217] train-rmse:1.273229 test-rmse:1.508357
[218] train-rmse:1.272960 test-rmse:1.508903
[219] train-rmse:1.272684 test-rmse:1.508306
[220] train-rmse:1.271494 test-rmse:1.507825
[221] train-rmse:1.269094 test-rmse:1.508296
[222] train-rmse:1.267877 test-rmse:1.507658
[223] train-rmse:1.267245 test-rmse:1.507634
[224] train-rmse:1.266839 test-rmse:1.506810
[225] train-rmse:1.266145 test-rmse:1.506376
[226] train-rmse:1.264481 test-rmse:1.505032
[227] train-rmse:1.263460 test-rmse:1.504897
[228] train-rmse:1.261385 test-rmse:1.504549
[229] train-rmse:1.260521 test-rmse:1.504311
[230] train-rmse:1.259975 test-rmse:1.504298
[231] train-rmse:1.258607 test-rmse:1.503251
[232] train-rmse:1.258242 test-rmse:1.503352
[233] train-rmse:1.257318 test-rmse:1.503161
[234] train-rmse:1.256750 test-rmse:1.502992
[235] train-rmse:1.256474 test-rmse:1.502954
[236] train-rmse:1.254916 test-rmse:1.503106
[237] train-rmse:1.253671 test-rmse:1.502871
[238] train-rmse:1.253077 test-rmse:1.502255
[239] train-rmse:1.251273 test-rmse:1.502482
[240] train-rmse:1.250030 test-rmse:1.502108
[241] train-rmse:1.248961 test-rmse:1.502084
[242] train-rmse:1.247390 test-rmse:1.502092
[243] train-rmse:1.245923 test-rmse:1.502468
[244] train-rmse:1.244874 test-rmse:1.502494
[245] train-rmse:1.243349 test-rmse:1.502799
[246] train-rmse:1.242638 test-rmse:1.502397
[247] train-rmse:1.241796 test-rmse:1.502076
[248] train-rmse:1.240354 test-rmse:1.501556
[249] train-rmse:1.238875 test-rmse:1.501306
[250] train-rmse:1.237827 test-rmse:1.500955
[251] train-rmse:1.236400 test-rmse:1.501387
[252] train-rmse:1.235681 test-rmse:1.501140
[253] train-rmse:1.235121 test-rmse:1.500464
[254] train-rmse:1.234814 test-rmse:1.500631
[255] train-rmse:1.234268 test-rmse:1.500691
[256] train-rmse:1.232966 test-rmse:1.500681
[257] train-rmse:1.232066 test-rmse:1.500582
[258] train-rmse:1.230420 test-rmse:1.499850
[259] train-rmse:1.229120 test-rmse:1.499926
[260] train-rmse:1.228644 test-rmse:1.499728
[261] train-rmse:1.227464 test-rmse:1.498909
[262] train-rmse:1.226840 test-rmse:1.499525
[263] train-rmse:1.225092 test-rmse:1.499614
[264] train-rmse:1.223459 test-rmse:1.499565
[265] train-rmse:1.222515 test-rmse:1.499307
[266] train-rmse:1.221418 test-rmse:1.498816
[267] train-rmse:1.219985 test-rmse:1.498760
[268] train-rmse:1.218894 test-rmse:1.498425
[269] train-rmse:1.217882 test-rmse:1.498828
[270] train-rmse:1.217139 test-rmse:1.498771
[271] train-rmse:1.216292 test-rmse:1.498499
[272] train-rmse:1.215450 test-rmse:1.498877
[273] train-rmse:1.214858 test-rmse:1.498614
[274] train-rmse:1.214581 test-rmse:1.498343
[275] train-rmse:1.213839 test-rmse:1.498393
[276] train-rmse:1.213249 test-rmse:1.498096
[277] train-rmse:1.212644 test-rmse:1.498089
[278] train-rmse:1.211809 test-rmse:1.497718
[279] train-rmse:1.210446 test-rmse:1.497280
[280] train-rmse:1.209677 test-rmse:1.497491
[281] train-rmse:1.208185 test-rmse:1.497479
[282] train-rmse:1.207537 test-rmse:1.496418
[283] train-rmse:1.206724 test-rmse:1.496101
[284] train-rmse:1.205690 test-rmse:1.496282
[285] train-rmse:1.204925 test-rmse:1.496391
[286] train-rmse:1.203675 test-rmse:1.496472
[287] train-rmse:1.203177 test-rmse:1.495836
[288] train-rmse:1.202126 test-rmse:1.495650
[289] train-rmse:1.201412 test-rmse:1.495021
[290] train-rmse:1.200453 test-rmse:1.494902
[291] train-rmse:1.200102 test-rmse:1.494974
[292] train-rmse:1.199601 test-rmse:1.494852
[293] train-rmse:1.198060 test-rmse:1.494143
[294] train-rmse:1.197222 test-rmse:1.494127
[295] train-rmse:1.196683 test-rmse:1.494231
[296] train-rmse:1.195828 test-rmse:1.494283
[297] train-rmse:1.195332 test-rmse:1.494274
[298] train-rmse:1.195015 test-rmse:1.494039
[299] train-rmse:1.194363 test-rmse:1.494159
[300] train-rmse:1.192917 test-rmse:1.494228
[301] train-rmse:1.192444 test-rmse:1.494134
[302] train-rmse:1.191284 test-rmse:1.494156
[303] train-rmse:1.190979 test-rmse:1.493899
[304] train-rmse:1.190745 test-rmse:1.493904
[305] train-rmse:1.190359 test-rmse:1.493914
[306] train-rmse:1.190150 test-rmse:1.493813
[307] train-rmse:1.189688 test-rmse:1.493364
[308] train-rmse:1.189204 test-rmse:1.493641
[309] train-rmse:1.188468 test-rmse:1.493777
[310] train-rmse:1.188206 test-rmse:1.493624
[311] train-rmse:1.186628 test-rmse:1.493288
[312] train-rmse:1.185289 test-rmse:1.493302
[313] train-rmse:1.184616 test-rmse:1.493028
[314] train-rmse:1.183890 test-rmse:1.492789
[315] train-rmse:1.183583 test-rmse:1.492818
[316] train-rmse:1.182808 test-rmse:1.492770
[317] train-rmse:1.181566 test-rmse:1.492502
[318] train-rmse:1.180222 test-rmse:1.492560
[319] train-rmse:1.178591 test-rmse:1.493058
[320] train-rmse:1.177953 test-rmse:1.492492
[321] train-rmse:1.177245 test-rmse:1.492405
[322] train-rmse:1.176543 test-rmse:1.492441
[323] train-rmse:1.175385 test-rmse:1.493005
[324] train-rmse:1.174668 test-rmse:1.492953
[325] train-rmse:1.174383 test-rmse:1.492965
[326] train-rmse:1.173127 test-rmse:1.492967
[327] train-rmse:1.172500 test-rmse:1.492678
[328] train-rmse:1.171769 test-rmse:1.492769
[329] train-rmse:1.171362 test-rmse:1.492801
[330] train-rmse:1.170012 test-rmse:1.492165
[331] train-rmse:1.169306 test-rmse:1.492075
[332] train-rmse:1.168296 test-rmse:1.491973
[333] train-rmse:1.167846 test-rmse:1.491976
[334] train-rmse:1.167057 test-rmse:1.492050
[335] train-rmse:1.166430 test-rmse:1.491773
[336] train-rmse:1.165876 test-rmse:1.491870
[337] train-rmse:1.164803 test-rmse:1.491136
[338] train-rmse:1.164518 test-rmse:1.491153
[339] train-rmse:1.163673 test-rmse:1.490802
[340] train-rmse:1.163189 test-rmse:1.491108
[341] train-rmse:1.162067 test-rmse:1.490937
[342] train-rmse:1.161537 test-rmse:1.490558
[343] train-rmse:1.161237 test-rmse:1.490392
[344] train-rmse:1.160393 test-rmse:1.490689
[345] train-rmse:1.160178 test-rmse:1.490589
[346] train-rmse:1.159091 test-rmse:1.490629
[347] train-rmse:1.157898 test-rmse:1.490452
[348] train-rmse:1.157138 test-rmse:1.490365
[349] train-rmse:1.156410 test-rmse:1.490741
[350] train-rmse:1.155877 test-rmse:1.490953
[351] train-rmse:1.155536 test-rmse:1.491106
[352] train-rmse:1.154819 test-rmse:1.491003
[353] train-rmse:1.154440 test-rmse:1.490788
[354] train-rmse:1.153897 test-rmse:1.490694
[355] train-rmse:1.152675 test-rmse:1.490351
[356] train-rmse:1.151644 test-rmse:1.490470
[357] train-rmse:1.151174 test-rmse:1.490329
[358] train-rmse:1.150587 test-rmse:1.490575
[359] train-rmse:1.150088 test-rmse:1.491189
[360] train-rmse:1.149410 test-rmse:1.491304
[361] train-rmse:1.149056 test-rmse:1.491135
[362] train-rmse:1.147973 test-rmse:1.491083
[363] train-rmse:1.147359 test-rmse:1.490830
[364] train-rmse:1.146764 test-rmse:1.490997
[365] train-rmse:1.146508 test-rmse:1.490921
[366] train-rmse:1.146336 test-rmse:1.490720
[367] train-rmse:1.144954 test-rmse:1.490063
[368] train-rmse:1.143758 test-rmse:1.490015
[369] train-rmse:1.143122 test-rmse:1.490120
[370] train-rmse:1.142812 test-rmse:1.490124
[371] train-rmse:1.142312 test-rmse:1.489949
[372] train-rmse:1.141067 test-rmse:1.490203
[373] train-rmse:1.140787 test-rmse:1.490122
[374] train-rmse:1.140185 test-rmse:1.490181
[375] train-rmse:1.139396 test-rmse:1.489990
[376] train-rmse:1.139169 test-rmse:1.489939
[377] train-rmse:1.138669 test-rmse:1.489318
[378] train-rmse:1.138146 test-rmse:1.489294
[379] train-rmse:1.136919 test-rmse:1.489769
[380] train-rmse:1.135952 test-rmse:1.489362
[381] train-rmse:1.135390 test-rmse:1.489275
[382] train-rmse:1.134624 test-rmse:1.489626
[383] train-rmse:1.133474 test-rmse:1.489248
[384] train-rmse:1.133210 test-rmse:1.489120
[385] train-rmse:1.132632 test-rmse:1.488904
[386] train-rmse:1.132148 test-rmse:1.488715
[387] train-rmse:1.131356 test-rmse:1.488473
[388] train-rmse:1.131164 test-rmse:1.488412
[389] train-rmse:1.130669 test-rmse:1.488147
[390] train-rmse:1.129956 test-rmse:1.487972
[391] train-rmse:1.129246 test-rmse:1.487718
[392] train-rmse:1.128394 test-rmse:1.487495
[393] train-rmse:1.127913 test-rmse:1.487526
[394] train-rmse:1.126581 test-rmse:1.486545
[395] train-rmse:1.125568 test-rmse:1.486624
[396] train-rmse:1.124985 test-rmse:1.486401
[397] train-rmse:1.124178 test-rmse:1.485516
[398] train-rmse:1.123257 test-rmse:1.485276
[399] train-rmse:1.122939 test-rmse:1.485114
[400] train-rmse:1.122597 test-rmse:1.485119
[401] train-rmse:1.122093 test-rmse:1.485188
[402] train-rmse:1.121594 test-rmse:1.485019
[403] train-rmse:1.120478 test-rmse:1.484634
[404] train-rmse:1.119523 test-rmse:1.484862
[405] train-rmse:1.119322 test-rmse:1.484737
[406] train-rmse:1.119144 test-rmse:1.484662
[407] train-rmse:1.117965 test-rmse:1.484264
[408] train-rmse:1.116427 test-rmse:1.483442
[409] train-rmse:1.115212 test-rmse:1.483246
[410] train-rmse:1.114522 test-rmse:1.483080
[411] train-rmse:1.114036 test-rmse:1.483384
[412] train-rmse:1.112574 test-rmse:1.483917
[413] train-rmse:1.111843 test-rmse:1.483857
[414] train-rmse:1.111110 test-rmse:1.483775
[415] train-rmse:1.110860 test-rmse:1.483566
[416] train-rmse:1.109916 test-rmse:1.483141
[417] train-rmse:1.109377 test-rmse:1.483601
[418] train-rmse:1.109026 test-rmse:1.483764
[419] train-rmse:1.108832 test-rmse:1.483687
[420] train-rmse:1.107931 test-rmse:1.483368
[421] train-rmse:1.107425 test-rmse:1.483507
[422] train-rmse:1.106677 test-rmse:1.483907
[423] train-rmse:1.106439 test-rmse:1.484152
[424] train-rmse:1.106228 test-rmse:1.484004
[425] train-rmse:1.105633 test-rmse:1.483829
[426] train-rmse:1.105274 test-rmse:1.483925
[427] train-rmse:1.103790 test-rmse:1.483335
[428] train-rmse:1.102676 test-rmse:1.483325
[429] train-rmse:1.101779 test-rmse:1.483764
[430] train-rmse:1.101503 test-rmse:1.483741
Stopping. Best iteration:
[410] train-rmse:1.114522 test-rmse:1.483080
# Create directory if it doesn't exist
if (!dir.exists("../data/models")) {
dir.create("../data/models", recursive = TRUE)
}
# Save the final XGBoost model
saveRDS(final_model, file = "../data/models/xgb_model.rds")
# Save the best parameters
saveRDS(best_params, file = "../data/models/xgb_best_params.rds")
cat("Model and parameters saved to ../data/models/")
Model and parameters saved to ../data/models/
library(Metrics)
library(ggplot2)
library(dplyr)
set.seed(2025)
# Predict on test set
preds <- predict(final_model, as.matrix(X_test))
# --- Metrics ---
rmse <- sqrt(mean((y_test - preds)^2))
mae <- mean(abs(y_test - preds))
mape <- mean(abs((y_test - preds) / y_test)) * 100
r2 <- 1 - (sum((y_test - preds)^2) / sum((y_test - mean(y_test))^2))
cat("Model Evaluation Metrics:\n")
Model Evaluation Metrics:
cat(" RMSE:", rmse, "\n")
RMSE: 1.48308
cat(" MAE :", mae, "\n")
MAE : 0.7919562
cat(" MAPE:", mape, "%\n")
MAPE: Inf %
cat(" R² :", r2, "\n\n")
R² : 0.321709
# --- Residuals ---
residuals <- y_test - preds
residual_df <- data.frame(
actual = y_test,
predicted = preds,
residuals = residuals
)
# --- Plot: Predicted vs Actual ---
p1 <- residual_df %>%
ggplot(aes(x = actual, y = predicted)) +
geom_point(alpha = 0.5) +
geom_abline(slope = 1, intercept = 0, color = "red") +
theme_minimal() +
labs(title = "Predicted vs Actual Crash Rates",
x = "Actual",
y = "Predicted")
# --- Plot: Residuals vs Predicted ---
p2 <- residual_df %>%
ggplot(aes(x = predicted, y = residuals)) +
geom_point(alpha = 0.5, color = "blue") +
geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
theme_minimal() +
labs(title = "Residuals vs Predicted",
x = "Predicted",
y = "Residuals")
# --- Plot: Residual Density ---
p3 <- residual_df %>%
ggplot(aes(x = residuals)) +
geom_histogram(aes(y = ..density..), bins = 30, fill = "skyblue", alpha = 0.7) +
geom_density(color = "red") +
theme_minimal() +
labs(title = "Residual Distribution",
x = "Residuals",
y = "Density")
# Print plots
print(p1)
print(p2)
print(p3)
ggsave("../report/plots/predicted_vs_actual_values_plot.png", p1, width = 10, height = 6, dpi = 300)
ggsave("../report/plots/resisuals_vs_predicted_values_plot.png", p2, width = 10, height = 6, dpi = 300)
ggsave("../report/plots/residual_density_plot.png", p3, width = 10, height = 6, dpi = 300)
# Compute SHAP values
shap_values <- shap.values(xgb_model = final_model, X_train = as.matrix(X_train))
shap_long <- shap.prep(shap_contrib = shap_values$shap_score, X_train = as.matrix(X_train))
# SHAP summary plot
print(shap.plot.summary(shap_long))
if (!dir.exists("../report/plots")) {
dir.create("../report/plots")
}
png("../report/plots/shap_summary_plot.png", width = 10, height = 6, dpi = 300)
Error in check.options(new, name.opt = ".X11.Options", envir = .X11env) :
c("invalid argument name ‘dpi’ in 'png(\"../report/plots/shap_summary_plot.png\", width = 10, height = 6, '", "invalid argument name ‘dpi’ in ' dpi = 300)'")
xgb.plot.tree(model = final_model, trees = 0)
xgb.plot.tree(model = final_model, trees = 1)
xgb.plot.tree(model = final_model, trees = 2)
xgb.plot.multi.trees(model = final_model)
# ============================================================
# Additional Model Diagnostics and Deeper Analysis
# ============================================================
library(ggplot2)
library(dplyr)
library(pdp) # For Partial Dependence Plots
library(DALEX) # For model explainability
library(ggthemes)
library(sf)
# ---------------------------
# 1. SHAP Dependence and Interaction Plots
# ---------------------------
message("\nGenerating SHAP dependence and interaction plots...")
# Assuming shap_values and shap_long are already computed
# (if not, recompute them using iml or SHAPforxgboost packages)
# Top feature by SHAP importance
top_feature <- shap_long %>%
as_tibble() %>%
count(variable, wt = abs(value), sort = TRUE) %>%
dplyr::slice(1) %>%
pull(variable)
# Dependence plot for top feature
shap.plot.dependence(data_long = shap_long, x = top_feature, color_feature = top_feature)
# Interaction values
shap_interaction_values <- predict(
final_model,
as.matrix(X_train),
predinteraction = TRUE
)
# shap_interaction_values will be a 3D array: [n_samples, n_features, n_features]
dim(shap_interaction_values)
[1] 10816 48 48
# ---------------------------
# 2. Residual Plots and Mapping
# ---------------------------
message("\nComputing residuals and creating residual plots...")
preds <- predict(final_model, as.matrix(X_test))
residuals <- y_test - preds
residual_df <- data.frame(
observed = y_test,
predicted = preds,
residual = residuals
)
# Predicted vs Observed
ggplot(residual_df, aes(x = predicted, y = observed)) +
geom_point(alpha = 0.6) +
geom_abline(slope = 1, intercept = 0, color = "red") +
theme_minimal() +
labs(title = "Predicted vs. Observed Crash Rates",
x = "Predicted Crash Rate per 1,000",
y = "Observed Crash Rate per 1,000")
# Residual Histogram
ggplot(residual_df, aes(x = residual)) +
geom_histogram(binwidth = 0.2, fill = "steelblue", color = "white") +
theme_minimal() +
labs(title = "Residual Distribution", x = "Residuals", y = "Count")
# ---------------------------
# 3. Partial Dependence Plots (PDP)
# ---------------------------
message("\nGenerating Partial Dependence Plots...")
top_features <- shap_long %>%
count(variable, wt = abs(value), sort = TRUE) %>%
dplyr::slice(1:10) %>%
pull(variable)
for (f in top_features) {
pd <- partial(final_model, pred.var = f, train = as.matrix(X_train), grid.resolution = 30)
plot(pd, main = paste("Partial Dependence of", f))
}